feat(onnx): bayesflow → ONNX exporter (transform_bayesflow_to_onnx)#80
Open
AlexanderFengler wants to merge 1 commit into
Adds lanfactory.onnx.transform_bayesflow_to_onnx, the bayesflow sibling of transform_sbi_to_onnx (PR #79). Wraps a trained bayesflow ContinuousApproximator (NLE) or RatioApproximator (NRE) and writes a single-trial ONNX file consumable by HSSM's loglik_kind="approx_differentiable" path. Same I/O contract as the sbi exporter (rank-1 input [theta..., x...], rank-0 scalar log-likelihood, opset 17) so HSSM ingests both via the same loglik="*.onnx" gesture with zero HSSM-side changes.

What's in this commit

- src/lanfactory/onnx/bayesflow.py: exporter module mirroring sbi.py. Contains _BayesflowNLELogProbWrapper and _BayesflowNRELogRatioWrapper. Pre-evaluates the bayesflow Standardize layer's moving mean/std to torch buffer constants at wrapper construction time so the ONNX trace is fully static (avoids If, Size, Tile dynamic-shape ops that jaxonnxruntime can't run). Guards on KERAS_BACKEND=torch and identity Adapter; both raise actionable errors with concrete fix hints.
- src/lanfactory/onnx/__init__.py: export the new function.
- pyproject.toml: add [bayesflow] optional extra (bayesflow>=2.0.8, keras>=3.12), add to [all] and [dev]. Also refactors the existing sbi+nflows pair into its own [sbi] extra (mirroring the new [bayesflow]) while keeping them in [all].
- tests/test_bayesflow_nle_export.py: 6 tests. Three-way numerical agreement (torch reference wrapper <-> onnxruntime <-> jaxonnxruntime) at atol=1e-5, gradient agreement at atol=1e-4, log-prob ordering sanity, and three guard tests (wrong backend, non-identity adapter, wrong mode).
- tests/test_bayesflow_nre_export.py: 4 tests. Same shape for the NRE path on a RatioApproximator.
- tests/test_bayesflow_hssm_integration.py: end-to-end DDM smoke (pytest.importorskip("hssm")). Mirrors test_sbi_hssm_integration.py.
- docs/exporting_bayesflow_models.md: full constraint catalog (KERAS_BACKEND, CouplingFlow knobs, silu vs hard_silu activation choice, identity-adapter requirement, JAX x64). Quick-starts for NLE and NRE. "Two paths into HSSM" framing alongside the JAX-callable path used in bayesflow_lre_integration.ipynb.

v1 constraints (documented, enforced where introspectable)

User must train with:

- permutation=None (FixedPermutation -> aten::ravel, unsupported)
- use_actnorm=False (untested in v1)
- transform=AffineTransform(clamp=False) explicit instance (find_transform("affine") drops kwargs - bayesflow upstream bug)
- subnet_kwargs.activation="silu" or another smooth activation (default hard_silu emits HardSwish, no jaxonnxruntime handler)
- identity Adapter (numpy-only adapter ops cannot be baked into ONNX)

Bayesflow continuous observations only. MNLE-style discrete + continuous deferred until upstream MNLE support lands.

Numerical guarantees

19 passing tests across both bayesflow and sbi tracks; no regressions on the existing sbi exporter. Each export is verified for three-way numerical agreement at 1e-5 and gradient agreement at 1e-4.

Companion PRs

- HSSM: docs(tutorials): add bayesflow_nle_onnx_integration.ipynb on a fresh bayesflow-integration branch off main (sibling, not child, of the sbi-integration branch in PR #964).
- HSSMSpine: bayesflow-onnx-integration.md design doc + upstream-bugs-from-bayesflow-onnx-work.md catalog of upstream defects surfaced during this work (jaxonnxruntime missing HardSwish/Size handlers; bayesflow find_transform kwarg-drop bug; bayesflow global torch.set_grad_enabled(False) cross-library leak; torch.onnx missing aten::ravel/asinh symbolic registrations).

This branch is stacked on sbi-connector (PR #79). When #79 merges, this PR's base auto-retargets to main.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pull request overview
Adds a BayesFlow-to-ONNX exporter so BayesFlow-trained NLE/NRE models can be consumed by HSSM via the same loglik="model.onnx", loglik_kind="approx_differentiable" pathway already used for LAN and sbi exports.
Changes:
- Introduces `lanfactory.onnx.transform_bayesflow_to_onnx` with NLE/NRE wrappers that bake BayesFlow standardization statistics into a static ONNX trace.
- Adds BayesFlow-focused regression tests (three-way numerical agreement + gradient agreement) and an optional HSSM integration smoke test.
- Adds a `bayesflow` optional extra (and a dedicated `sbi` extra) plus a new documentation guide for BayesFlow exports.
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `src/lanfactory/onnx/bayesflow.py` | New BayesFlow exporter implementation (NLE + NRE wrappers + ONNX export entry point). |
| `src/lanfactory/onnx/__init__.py` | Exposes `transform_bayesflow_to_onnx` at the package level. |
| `pyproject.toml` | Adds `bayesflow` and `sbi` extras; extends `all` and `dev` deps. |
| `tests/test_bayesflow_nle_export.py` | New NLE export tests: forward/grad agreement + guardrails. |
| `tests/test_bayesflow_nre_export.py` | New NRE export tests: forward/grad agreement + guardrails. |
| `tests/test_bayesflow_hssm_integration.py` | New (skip-if-missing-HSSM) end-to-end BayesFlow→ONNX→HSSM integration test. |
| `docs/exporting_bayesflow_models.md` | New BayesFlow ONNX export guide and constraint catalog. |
| `uv.lock` | Dependency lockfile updates to include BayesFlow/Keras and related resolutions. |
Comment on lines +32 to +34

    import pandas as pd  # noqa: E402

    import bayesflow as bf  # noqa: E402
Comment on lines +1 to +10

    # Exporting bayesflow-trained networks to ONNX

    LANfactory's [`transform_bayesflow_to_onnx`](api/onnx.md) is the bayesflow
    sibling of [`transform_sbi_to_onnx`](exporting_sbi_models.md). It wraps a
    trained [`bayesflow`](https://github.com/bayesflow-org/bayesflow)
    `ContinuousApproximator` (NLE) or `RatioApproximator` (NRE) and writes a
    single-trial ONNX file that HSSM's `loglik_kind="approx_differentiable"`
    path can consume exactly like an sbi export. Same user gesture, same file
    format, same HSSM-side loader — regardless of which training framework you
    came from.
Comment on lines +15 to +16

    os.environ.setdefault("KERAS_BACKEND", "torch")
    os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")

    # why; same reasoning here.
    import os

    os.environ.setdefault("KERAS_BACKEND", "torch")
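The `os.environ.setdefault` calls in these snippets only take effect when the variable is not already set, which is presumably why the exporter also guards on `KERAS_BACKEND=torch`. A minimal sketch of that interaction follows; the function name `check_keras_backend` and the error wording are hypothetical and are not taken from `bayesflow.py`.

```python
import os
from typing import Mapping


def check_keras_backend(environ: Mapping[str, str] = os.environ) -> None:
    """Hypothetical guard: refuse to export unless Keras is set to use torch."""
    backend = environ.get("KERAS_BACKEND")
    if backend != "torch":
        raise RuntimeError(
            f"KERAS_BACKEND is {backend!r}, but this export path needs 'torch'. "
            "Set os.environ['KERAS_BACKEND'] = 'torch' before the first import "
            "of keras or bayesflow."
        )


# setdefault does NOT override a value that was already exported in the shell.
env = {"KERAS_BACKEND": "jax"}
env.setdefault("KERAS_BACKEND", "torch")  # no effect: the key already exists

try:
    check_keras_backend(env)
except RuntimeError as err:
    print(err)  # the guard reports the mismatch instead of failing later
```

Passing the environment as an argument keeps the check testable without mutating the real process environment.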
Comment on lines +22 to +23

    os.environ.setdefault("KERAS_BACKEND", "torch")
    os.environ.setdefault("KERAS_TORCH_DEVICE", "cpu")

    @@ -82,6 +86,8 @@ dev = [
        "jaxonnxruntime>=0.3",
        "onnxruntime>=1.17",
        "nflows>=0.14",
Summary

Adds `lanfactory.onnx.transform_bayesflow_to_onnx`, the bayesflow sibling of `transform_sbi_to_onnx` from #79. Wraps a trained `bayesflow.ContinuousApproximator` (NLE) or `RatioApproximator` (NRE) and writes a single-trial ONNX file consumable by HSSM's `loglik_kind="approx_differentiable"` path.

The user gesture becomes the same regardless of training framework: `hssm.HSSM(loglik="model.onnx", loglik_kind="approx_differentiable")` works for LAN, sbi, and now bayesflow exports through the identical HSSM-side code path (no HSSM changes required).

Branch relationship

This branch is stacked on `sbi-connector` (#79). The diff in this PR is exactly the bayesflow-specific additions; #79's commits are not duplicated. When #79 merges to `main`, GitHub will offer to auto-retarget this PR's base to `main`.

What's in this PR

- `src/lanfactory/onnx/bayesflow.py` (… `transform_bayesflow_to_onnx`)
- `src/lanfactory/onnx/__init__.py`
- `pyproject.toml`: `[bayesflow]` optional extra; refactor existing sbi+nflows pair into a new `[sbi]` extra (symmetric with `[bayesflow]`); both added to `[all]` and `[dev]`
- `tests/test_bayesflow_nle_export.py`
- `tests/test_bayesflow_nre_export.py`
- `tests/test_bayesflow_hssm_integration.py` (`pytest.importorskip("hssm")`, mirrors `test_sbi_hssm_integration.py`)
- `docs/exporting_bayesflow_models.md` (… `exporting_sbi_models.md`)

Architectural contract

Same I/O contract as the sbi exporter:

- Input: rank-1, shape `(theta_dim + x_dim,)`, parameters first then observations.
- Output: rank-0 scalar log-likelihood.
- Opset 17, for `jaxonnxruntime` reproducibility.

Key implementation choice: the wrapper bakes the bayesflow `Standardize` layer's accumulated `moving_mean`/`moving_std` as torch buffer constants at construction time. This sidesteps the dynamic-shape ops (`If`, `Size`, `Tile`) that the live Keras layer would emit at trace time; `jaxonnxruntime` doesn't have a `Size` handler. The constants are correct because training is complete by export time.

v1 constraints (documented, enforced where introspectable)
User must train with:

- `permutation=None` (FixedPermutation → `aten::ravel`, unsupported in opset 17/20)
- `use_actnorm=False` (untested in v1)
- `transform=AffineTransform(clamp=False)` as an explicit instance (`find_transform("affine")` silently drops `transform_kwargs`; bayesflow upstream bug, catalogued in companion HSSMSpine PR)
- `subnet_kwargs.activation="silu"` (default `hard_silu` exports as the fused ONNX op `HardSwish`, no jaxonnxruntime handler; silu decomposes to `Sigmoid + Mul`)
- identity `Adapter` (numpy-only Adapter ops can't be baked into ONNX)

Each violation produces an actionable error message at export time.
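To make the activation constraint above concrete, here is a small pure-Python sketch of the two functions involved. It only illustrates the math; it is not code from the exporter, and the helper names are chosen for this example. `silu` is a product of the input and a sigmoid (hence the `Sigmoid + Mul` decomposition), while `hard_silu` is the piecewise-linear form that ONNX represents as the single `HardSwish` op.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    # Smooth activation: the input multiplied by its sigmoid.
    return x * sigmoid(x)


def hard_silu(x: float) -> float:
    # Piecewise-linear approximation: x * relu6(x + 3) / 6,
    # the same function ONNX defines for HardSwish.
    return x * min(max(x + 3.0, 0.0), 6.0) / 6.0


for x in (-4.0, -1.0, 0.0, 1.0, 4.0):
    print(f"x={x:+.1f}  silu={silu(x):+.4f}  hard_silu={hard_silu(x):+.4f}")
```

The two curves are close but not identical, so switching the activation is a training-time decision rather than something that can be patched at export time.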
Test status

… `torch.set_grad_enabled(True)` after importing bayesflow to undo the global autograd disable that bayesflow's torch backend does at import time. This is a known upstream issue documented in the companion HSSMSpine PR's upstream-bugs catalog.

Companion PRs

- HSSM: `bayesflow-integration` branch (new), adds `docs/tutorials/bayesflow_nle_onnx_integration.ipynb`. Sibling, not child, of `sbi-integration` (#964); works against HSSM `main` standalone (Part 1 includes the manual `jaxort_only_allow_initializers_as_static_args=False` workaround that #964 plans to auto-handle inside `onnx2jax`).
- HSSMSpine: `bayesflow-onnx-plans` branch (stacked on #9 for the cross-reference to `sbi-onnx-integration.md`). Adds the design doc and an upstream-bugs catalog covering the five real upstream defects surfaced during this work.

Test plan

… `pytest.importorskip("hssm")`) once both packages can be installed in the same env

🤖 Generated with Claude Code
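The three-way numerical agreement referred to throughout this PR (torch reference wrapper, onnxruntime, jaxonnxruntime, at an absolute tolerance of 1e-5) amounts to a pairwise comparison of three scalar outputs. A minimal sketch, assuming the three values have already been computed; the helper name and the numbers below are placeholders for illustration, not results from the test suite.

```python
import math


def assert_three_way_close(torch_val: float, ort_val: float, jax_val: float,
                           atol: float = 1e-5) -> None:
    """Raise AssertionError unless all three scalars agree pairwise within atol."""
    pairs = {
        "torch vs onnxruntime": (torch_val, ort_val),
        "onnxruntime vs jaxonnxruntime": (ort_val, jax_val),
        "torch vs jaxonnxruntime": (torch_val, jax_val),
    }
    for name, (a, b) in pairs.items():
        # Purely absolute tolerance, mirroring an atol-style check.
        if not math.isclose(a, b, rel_tol=0.0, abs_tol=atol):
            raise AssertionError(f"{name}: |{a} - {b}| exceeds atol={atol}")


# Placeholder log-likelihood values, made up for this example.
assert_three_way_close(-1.234567, -1.234569, -1.234565)
```

Checking all three pairs, rather than chaining two comparisons, keeps the worst-case disagreement bounded by `atol` instead of `2 * atol`.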